from Network.network_utils import run_optimizer, pytorch_model
from tianshou.data import Batch
import numpy as np
import torch

def infer_gradient(args, params, model, batch, keep_all=False):
    """Run gradient-based interaction inference for the configured targets.

    When ``args.infer.infer_names`` is empty, a single pass is made with the
    ``"all"`` form and its result is returned directly. Otherwise the model's
    target is switched to each name in turn, a ``"full"`` form pass is made,
    and the per-target results are collected into a ``Batch`` keyed by name.

    Args:
        args: configuration namespace; ``args.infer.infer_names`` selects targets.
        params: forwarded unchanged to ``compute_gradient_vals``.
        model: model exposing ``set_target_name`` (used in the per-target case).
        batch: data batch forwarded to ``compute_gradient_vals``.
        keep_all: forwarded to ``compute_gradient_vals``.

    Returns:
        The ``compute_gradient_vals`` output for ``"all"``, or a ``Batch``
        mapping each target name to its ``"full"`` output.
    """
    if len(args.infer.infer_names) == 0:
        # No explicit targets: a single inference pass over every target.
        return compute_gradient_vals(args, params, model, batch, "all", keep_all=keep_all)

    per_target = Batch()
    for target in args.infer.infer_names:
        # compute_gradient_vals reads model.target_name, so it must be set first.
        model.set_target_name(target)
        per_target[target] = compute_gradient_vals(args, params, model, batch, "full", keep_all=keep_all)
    return per_target

def compute_gradient_vals(args, params, model, batch, form, keep_all=False, result=None):
    """Compute input-gradient based interaction masks and their error vs. ``batch.trace``.

    Args:
        args: configuration namespace. Reads ``args.factor.single_obj_dim``,
            ``args.factor.num_objects``, ``args.infer.gradient.select_ideal``
            and ``args.infer.gradient.gradient_threshold``.
        params: not used by this function; kept for interface compatibility.
        model: must provide ``infer``, ``get_model_optim``, ``extractor.get_index``
            and ``target_name``.
        batch: data batch. Reads ``valid`` and ``trace``, and ``done`` when
            ``keep_all`` is set.
        form: inference form key, ``"all"`` or ``"full"``.
        keep_all: forwarded to ``model.infer``. It also multiplies ``mask_logits``
            and ``bin_error`` by ``(1 - batch.done)``.
        result: optional precomputed output of ``model.infer``. When ``None``,
            inference is run here with input gradients enabled.

    Returns:
        A ``Batch`` with these fields:
        ``gradients``, ``grads``, ``mask_logits``, ``omit_flags``, ``trace``,
        ``utrace``, ``inter_one_trace_rate``, ``inter_zero_trace_rate``,
        ``inter_masks``, ``bin_error`` and ``total_error``.
    """
    # Explicit None check. The truthiness of a Batch goes through its __len__, so a
    # caller-supplied inference result could otherwise be ignored or raise.
    if result is not None:
        infer_result = result
    else:
        infer_result = model.infer(batch, batch.valid, [form], grad_settings=["input"], keep_all=keep_all)
    form_out = infer_result[form]
    out = Batch()

    # TODO: only input gradients are used; embedding gradients are not included yet.
    grad_variables = [form_out.full_active_input]
    compute_models, optims = model.get_model_optim([form])
    optim, compute_model = optims[0], compute_models[0]
    # no_step=True presumably skips the optimizer update so only gradients w.r.t.
    # grad_variables are returned -- confirm in run_optimizer.
    out.gradients = run_optimizer(optim, compute_model, -form_out.log_probs, grad_variables=grad_variables, no_step=True)

    target_idx = model.extractor.get_index([model.target_name])
    # TODO: logic might be wrong for "all".
    if form == "all":
        flat_grads = out.gradients[0]
    else:
        # Drop the leading single_obj_dim input columns before splitting per object.
        flat_grads = out.gradients[0, :, args.factor.single_obj_dim:]
    # Reshape into (gradients[0].shape[0], len(target_idx), num_objects, -1), i.e. per-object gradient groups.
    out.grads = flat_grads.reshape(out.gradients[0].shape[0], len(target_idx), args.factor.num_objects, -1)
    out.mask_logits = out.grads.abs().sum(dim=-1)
    out.omit_flags = form_out.omit_flags
    if keep_all:
        out.mask_logits = out.mask_logits * (1 - batch.done)

    out.trace = batch.trace[form_out.omit_flags[0]]
    out.utrace = out.trace if form == "all" else out.trace[:, target_idx]
    if len(out.utrace.shape) == 2:
        out.utrace = np.expand_dims(out.utrace, axis=1)

    def _capped_mean_logit(trace_value):
        # For each index of the last mask_logits dim: the mean logit over entries whose
        # trace equals trace_value, capped at 1. NOTE(review): an empty selection makes
        # np.mean return NaN (with a warning), and min(1, nan) then yields 1.
        return np.expand_dims(np.array([
            min(1, np.mean(pytorch_model.unwrap(out.mask_logits[..., i][out.utrace[..., i] == trace_value])))
            for i in range(out.mask_logits.shape[-1])
        ]), axis=0)

    out.inter_one_trace_rate = _capped_mean_logit(1)
    out.inter_zero_trace_rate = _capped_mean_logit(0)

    if args.infer.gradient.select_ideal:
        # Threshold halfway between the mean logits on trace==1 and trace==0 entries.
        # This uses the ground-truth trace to choose the threshold.
        threshold = (out.inter_one_trace_rate + out.inter_zero_trace_rate) / 2
    else:
        threshold = args.infer.gradient.gradient_threshold
    out.inter_masks = (pytorch_model.unwrap(out.mask_logits) > threshold).astype(int)

    out.bin_error = pytorch_model.unwrap(out.inter_masks) - out.utrace  # assume only one target
    out.total_error = np.abs(pytorch_model.unwrap(out.inter_masks) - out.utrace)  # assume only one target
    if keep_all:
        out.bin_error = out.bin_error * (1 - batch.done)
    # TODO: print out the inter masks for trace and not-trace
    return out

def compute_gradient_loss(model, batch, args, params, results, form, keep_all=False):
    """Return a norm penalty on the gradient-based interaction logits.

    Args:
        model, batch, args, params: forwarded to ``compute_gradient_vals``.
        results: optional precomputed ``model.infer`` output, or a tuple whose
            first element is one. When ``None``, ``compute_gradient_vals`` runs
            the inference itself.
        form: inference form key forwarded to ``compute_gradient_vals``.
        keep_all: forwarded to ``compute_gradient_vals``.

    Returns:
        ``torch.linalg.norm`` of ``mask_logits`` over the last dimension, with
        order ``args.inter.regularizers.norm_form``.
    """
    if isinstance(results, tuple):
        results = results[0]  # just use the first one for embedding losses, if multiple
    # Forward the caller's inference results (avoids re-running inference) and keep_all.
    # compute_gradient_vals takes the precomputed inference output as `result`.
    grad_vals = compute_gradient_vals(args, params, model, batch, form, keep_all=keep_all, result=results)
    magnitude_loss = torch.linalg.norm(grad_vals.mask_logits, ord=args.inter.regularizers.norm_form, dim=-1)
    return magnitude_loss
